{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os, sys\n",
    "# os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'\n",
    "# os.environ['CUDA_VISIBLE_DEVICES'] = '0'\n",
    "import tensorflow as tf\n",
    "tf.compat.v1.enable_eager_execution()\n",
    "\n",
    "import numpy as np\n",
    "import imageio\n",
    "import pprint\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import run_nerf\n",
    "import run_nerf_helpers\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load trained network weights\n",
    "Run `bash download_example_weights.sh` in the root directory if you need to download the Lego example weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "basedir = './logs'\n",
    "expname = 'lego_example'\n",
    "\n",
    "config = os.path.join(basedir, expname, 'config.txt')\n",
    "print('Args:')\n",
    "print(open(config, 'r').read())\n",
    "\n",
    "parser = run_nerf.config_parser()\n",
    "ft_str = '' \n",
    "ft_str = '--ft_path {}'.format(os.path.join(basedir, expname, 'model_200000.npy'))\n",
    "args = parser.parse_args('--config {} '.format(config) + ft_str)\n",
    "\n",
    "# Create nerf model\n",
    "_, render_kwargs_test, start, grad_vars, models = run_nerf.create_nerf(args)\n",
    "\n",
    "bds_dict = {\n",
    "    'near' : tf.cast(2., tf.float32),\n",
    "    'far' : tf.cast(6., tf.float32),\n",
    "}\n",
    "render_kwargs_test.update(bds_dict)\n",
    "\n",
    "print('Render kwargs:')\n",
    "pprint.pprint(render_kwargs_test)\n",
    "\n",
    "net_fn = render_kwargs_test['network_query_fn']\n",
    "print(net_fn)\n",
    "\n",
    "# Render an overhead view to check model was loaded correctly\n",
    "c2w = np.eye(4)[:3,:4].astype(np.float32) # identity pose matrix\n",
    "c2w[2,-1] = 4.\n",
    "H, W, focal = 800, 800, 1200.\n",
    "down = 8\n",
    "test = run_nerf.render(H//down, W//down, focal/down, c2w=c2w, **render_kwargs_test)\n",
    "img = np.clip(test[0],0,1)\n",
    "plt.imshow(img)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Query network on dense 3d grid of points"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "N = 256\n",
    "t = np.linspace(-1.2, 1.2, N+1)\n",
    "\n",
    "query_pts = np.stack(np.meshgrid(t, t, t), -1).astype(np.float32)\n",
    "print(query_pts.shape)\n",
    "sh = query_pts.shape\n",
    "flat = query_pts.reshape([-1,3])\n",
    "\n",
    "\n",
    "def batchify(fn, chunk):\n",
    "    if chunk is None:\n",
    "        return fn\n",
    "    def ret(inputs):\n",
    "        return tf.concat([fn(inputs[i:i+chunk]) for i in range(0, inputs.shape[0], chunk)], 0)\n",
    "    return ret\n",
    "    \n",
    "    \n",
    "fn = lambda i0, i1 : net_fn(flat[i0:i1,None,:], viewdirs=np.zeros_like(flat[i0:i1]), network_fn=render_kwargs_test['network_fine'])\n",
    "chunk = 1024*64\n",
    "raw = np.concatenate([fn(i, i+chunk).numpy() for i in range(0, flat.shape[0], chunk)], 0)\n",
    "raw = np.reshape(raw, list(sh[:-1]) + [-1])\n",
    "sigma = np.maximum(raw[...,-1], 0.)\n",
    "\n",
    "print(raw.shape)\n",
    "plt.hist(np.maximum(0,sigma.ravel()), log=True)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Marching cubes with [PyMCubes](https://github.com/pmneila/PyMCubes)\n",
    "Change `threshold` to use a different sigma threshold for the isosurface"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import mcubes\n",
    "\n",
    "threshold = 50.\n",
    "print('fraction occupied', np.mean(sigma > threshold))\n",
    "vertices, triangles = mcubes.marching_cubes(sigma, threshold)\n",
    "print('done', vertices.shape, triangles.shape)\n",
    "\n",
    "### Uncomment to save out the mesh\n",
    "# mcubes.export_mesh(vertices, triangles, \"logs/lego_example/lego_{}.dae\".format(N), \"lego\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Live preview with [trimesh](https://github.com/mikedh/trimesh)\n",
    "Click and drag to change viewpoint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import trimesh\n",
    "\n",
    "mesh = trimesh.Trimesh(vertices / N - .5, triangles)\n",
    "mesh.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Save out video with [pyrender](https://github.com/mmatl/pyrender)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "os.environ[\"PYOPENGL_PLATFORM\"] = \"egl\"\n",
    "import pyrender\n",
    "from load_blender import pose_spherical"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "scene = pyrender.Scene()\n",
    "scene.add(pyrender.Mesh.from_trimesh(mesh, smooth=False))\n",
    "\n",
    "# Set up the camera -- z-axis away from the scene, x-axis right, y-axis up\n",
    "camera = pyrender.PerspectiveCamera(yfov=np.pi / 3.0)\n",
    "\n",
    "camera_pose = pose_spherical(-20., -40., 1.).numpy()\n",
    "nc = pyrender.Node(camera=camera, matrix=camera_pose)\n",
    "scene.add_node(nc)\n",
    "\n",
    "# Set up the light -- a point light in the same spot as the camera\n",
    "light = pyrender.PointLight(color=np.ones(3), intensity=4.0)\n",
    "nl = pyrender.Node(light=light, matrix=camera_pose)\n",
    "scene.add_node(nl)\n",
    "\n",
    "# Render the scene\n",
    "r = pyrender.OffscreenRenderer(640, 480)\n",
    "color, depth = r.render(scene)\n",
    "\n",
    "plt.imshow(color)\n",
    "plt.show()\n",
    "plt.imshow(depth)\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "imgs = []\n",
    "for th in np.linspace(0, 360., 120+1)[:-1]:\n",
    "    camera_pose = pose_spherical(th, -40., 1.).numpy()\n",
    "    scene.set_pose(nc, pose=camera_pose)\n",
    "    imgs.append(r.render(scene)[0])\n",
    "f = 'logs/lego_example/lego_mesh_turntable.mp4'\n",
    "imageio.mimwrite(f, imgs, fps=30)\n",
    "print('done')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import HTML\n",
    "from base64 import b64encode\n",
    "mp4 = open(f,'rb').read()\n",
    "data_url = \"data:video/mp4;base64,\" + b64encode(mp4).decode()\n",
    "HTML(\"\"\"\n",
    "<video width=400 controls autoplay loop>\n",
    "      <source src=\"%s\" type=\"video/mp4\">\n",
    "</video>\n",
    "\"\"\" % data_url)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
